[test] test_topk_plain: parametrize sweep to fix collection-time OOM#3934
[test] test_topk_plain: parametrize sweep to fix collection-time OOM#3934JohnQinAMD wants to merge 1 commit into
Conversation
🏷️ CI GuideRuns automatically on every PR:
Extended tests (opt-in via labels):
|
There was a problem hiding this comment.
Pull request overview
This PR refactors op_tests/test_topk_plain.py to avoid collection/import-time GPU OOMs caused by a module-level sweep that allocates tensors for all cases up front. It converts the sweep into isolated, lazy-per-case execution using pytest parameterization, adds teardown to free GPU memory between cases, and preserves the end-of-run performance summary.
Changes:
- Replace the module-level sweep loop with
@pytest.mark.parametrizetest cases to prevent collection-time allocations and speed up collection. - Add an autouse fixture to GC +
torch.cuda.empty_cache()between cases and a session-scoped fixture to emit a single markdown perf summary (with a safe fallback whentabulateis missing). - Vectorize the input permutation generation and reduce iteration counts (1000 → 20) to keep this as a correctness-focused test.
💡 Add Copilot custom instructions for smarter, more guided reviews. Learn how to get started.
b427fd5 to
7c7c7f6
Compare
CI runs each op test as a script (`python3 <file>` in aiter_test.sh), not via pytest. test_topk_plain.py sweeps 84 cases up to 3072x131072 fp32; with no cleanup between cases, a memory-pressured gfx950 runner OOMs mid-sweep and the whole file exits non-zero -> intermittent shard failures unrelated to the code under test. Fix (keeps the `python3 <file>` contract): - free each case's tensors (del + gc + torch.cuda.empty_cache()) so peak memory is one case, not the whole sweep - guard the sweep under `if __name__ == "__main__":` (clean import) - num_iters 1000 -> 100 (correctness check, not a perf gate; was the slowest file in its shard) - vectorize the per-row permutation (drops a batch_size-long Python loop) - assert on checkAllclose's returned error ratio (it does NOT raise) so an incorrect topk_plain actually fails CI instead of silently passing - summary table falls back to df.to_string when `tabulate` is absent Validated: `python3 op_tests/test_topk_plain.py` runs all 84 cases, all pass (err 0), summary prints, exit 0 on MI355X (gfx950). Co-Authored-By: Claude Opus 4.8 <noreply@anthropic.com>
7c7c7f6 to
9a2655e
Compare
zufayu
left a comment
There was a problem hiding this comment.
LGTM — the change is a test quality fix (OOM prevention, add assert on error ratio, reduce iter count). The single CI failure () is a pre-existing issue on main (same in , unrelated to this PR).
zufayu
left a comment
There was a problem hiding this comment.
LGTM. The change is a test quality fix: OOM prevention, adds assert on error ratio, reduces iter count. The single CI failure (Standard Tests MI35X shard 6) is a pre-existing issue on main — same RuntimeError in test_moe_2stage.py, unrelated to this PR.
|
@valarLip Could you merge this when you get a chance? The CI failure (Standard Tests MI35X shard 6) is a pre-existing issue on main branch (RuntimeError in test_moe_2stage.py, same failure exists on main today), unrelated to this PR's changes. |
Problem
CI runs each op test as a script (
python3 <file>in.github/scripts/aiter_test.sh), not via pytest.test_topk_plain.pysweeps 84 cases up to3072 × 131072fp32 and never frees tensors between cases, so on a memory-pressured gfx950 runner it OOMs mid-sweep and the whole file exits non-zero — producing intermittent "Standard Tests (MI35X, 8, 5)" failures unrelated to the code under test. On a clean GPU it passes; the failure is purely residual-memory dependent.Fix (keeps the
python3 <file>contract)del+gc.collect()+torch.cuda.empty_cache()) so peak memory is one case, not the whole sweep — the actual OOM fix.if __name__ == "__main__":so import is side-effect-free.num_iters1000 → 100 — this is a correctness check, not a perf gate; the high iteration count made it the slowest file in its shard for no benefit.batch_size-long Python loop).checkAllclose's returned error ratio — it does not raise, so the test previously would have passed even iftopk_plainwere incorrect.df.to_stringwhentabulateis absent, so the informational summary can never fail the run.Validation
python3 op_tests/test_topk_plain.pyon MI355X (gfx950): all 84 cases run, allpassed~(err 0), summary prints, exit 0.Addresses both Copilot review comments (script-style no-op under
python3; assert on the error ratio).Motivation
Technical Details
Test Plan
Test Result
Submission Checklist